import json
from collections import defaultdict, Counter
import argparse
import os
import sys
from glob import glob
import copy
from tqdm import tqdm
from datasets import load_dataset
import random

# Remove the interpreter's cap on int <-> str conversion length (0 disables the limit).
sys.set_int_max_str_digits(0)
# Append the grandparent of the directory containing this file to the import path.
# NOTE(review): presumably this is the repository root, so that project packages can be
# imported when the script is run directly -- confirm against the repo layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


# from post_processors.code.clean import tag_cleaner


def _read_json(path):
    """Load and return the JSON document stored at ``path``, closing the file afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _collect_correct_critiques(critiques):
    """Flatten critique records into one record per completion whose ``res`` entry is truthy."""
    correct_critiques = []
    for item in critiques:
        completions = item["completion"]
        # A single completion (string or parsed JSON object) is wrapped so it can be zipped.
        if isinstance(completions, (str, dict)):
            completions = [completions]

        # NOTE(review): ``pred`` and ``res`` are zipped as-is, i.e. they are assumed to be
        # sequences parallel to the completions -- confirm against the producer script.
        for comp, pred, res in zip(completions, item["pred"], item["res"]):
            if res:
                new_item = copy.deepcopy(item)
                new_item["completion"] = comp
                new_item["pred"] = pred
                new_item["res"] = res
                correct_critiques.append(new_item)
    return correct_critiques


def _load_completions(path_or_pattern):
    """Load completion records from one JSON file, or from every file matching a glob pattern."""
    if os.path.exists(path_or_pattern):
        return _read_json(path_or_pattern)

    data = []
    # glob() returns matches in arbitrary order; sort so a fixed seed gives a fixed output.
    for file in sorted(glob(path_or_pattern)):
        data.extend(_read_json(file))
    return data


def main():
    """Build "worsen" prompts from model completions and write them as JSON lines.

    For every completion record, a critique that passed execution (truthy ``res``) and
    belongs to a *different* problem is drawn at random and used as the in-context
    example of the prompt template. One output record is produced per non-empty
    prediction, until ``--sample_num`` records have been collected.

    Raises:
        ValueError: if the critique file contains no critique with a truthy ``res``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--critique_exec_file", type=str,
                        help="The file contains completion from the teacher model for critique, as well as the execution results. "
                             "The inputs for this file are generated by `pp_critique_difficulty` script.")
    parser.add_argument("--completion_file", type=str,
                        help="The file contains the completion for each query.")
    parser.add_argument("--completion_response_field", type=str, default="completion")
    parser.add_argument("--completion_problem_id_field", type=str, default="problem_id")
    parser.add_argument("--prompt_file", type=str, default="prompts/apps/worsen_from_feedback_0shot_v1.0.txt")
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--sample_num", type=int, default=2000)
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()

    random.seed(args.seed)

    with open(args.prompt_file, encoding="utf-8") as f:
        prompt_template = f.read()

    _dataset = load_dataset("codeparrot/apps", split=args.split).to_list()
    p_id2item = {item["problem_id"]: item for item in _dataset}

    correct_critiques = _collect_correct_critiques(_read_json(args.critique_exec_file))
    print(f"Total correct critiques: {len(correct_critiques)}")
    if not correct_critiques:
        # random.choice() on an empty list would fail later with an opaque IndexError.
        raise ValueError(f"No correct critiques found in {args.critique_exec_file}")
    critique_problem_ids = {critique["problem_id"] for critique in correct_critiques}

    data = _load_completions(args.completion_file)
    outputs = []
    for item in tqdm(data):
        problem_id = item[args.completion_problem_id_field]

        # If every correct critique is about this very problem, the rejection sampling
        # below could never terminate; skip the item instead of hanging.
        if critique_problem_ids <= {problem_id}:
            continue

        # The example critique must come from a different problem than the one being worsened.
        critique = random.choice(correct_critiques)
        while critique["problem_id"] == problem_id:
            critique = random.choice(correct_critiques)

        preds = item["pred"]
        if not preds:
            continue
        if isinstance(preds, str):
            preds = [preds]

        for i, pred in enumerate(preds):
            if not pred:
                continue

            prompt = prompt_template.format(
                example_question=critique["question"],
                example_code=critique["neg_code"],
                feedback=critique["completion"]["feedback"],
                corrected_program=critique["completion"]["corrected_program"],
                question=p_id2item[problem_id]["question"],
                code=pred,
            )
            new_item = copy.deepcopy(item)
            new_item["id"] = f"{problem_id}_neg{i}"
            new_item["prompt"] = prompt
            # Normalise the name of the completion field to "response" in the output.
            if args.completion_response_field != "response":
                new_item["response"] = new_item.pop(args.completion_response_field)
            outputs.append(new_item)

            if len(outputs) >= args.sample_num:
                break
        if len(outputs) >= args.sample_num:
            break

    print(f"Total number of items: {len(outputs)}")
    # Records are dumped with ensure_ascii=False, so the file must be opened as UTF-8
    # regardless of the platform's default encoding.
    with open(args.output_file, "w", encoding="utf-8") as f:
        for item in outputs:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()

"""
>>> python scripts/apps/pp_worsen_inputs.py --critique_exec_file outputs/apps/critique/apps.train.gpt4o.tem1.0.n11.neg.intro.inter.gpt4o.tem1.0.s42.n1.json_obj.exec.json \
    --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.0-of-8.v1.1.json" \
    --completion_response_field response \
    --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f2000.jsonl \
    --sample_num 2000 --seed 42 --completion_problem_id_field id
    
>>> python scripts/apps/pp_worsen_inputs.py --critique_exec_file outputs/apps/critique/apps.train.gpt4o.tem1.0.n11.neg.intro.inter.gpt4o.tem1.0.s42.n1.json_obj.exec.json \
    --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.?-of-8.v1.1.json" \
    --completion_response_field response \
    --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f10k.jsonl \
    --sample_num 10000 --seed 42 --completion_problem_id_field id

>>> python scripts/apps/pp_worsen_inputs.py --critique_exec_file outputs/apps/critique/apps.train.gpt4o.tem1.0.n11.neg.intro.inter.gpt4o.tem1.0.s42.n1.json_obj.exec.json \
    --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.?-of-8.v1.1.json" \
    --completion_response_field response \
    --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f100k.jsonl \
    --sample_num 100000 --seed 42 --completion_problem_id_field id

Total correct critiques: 7205
100%|██████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1070.68it/s]
Total number of items: 49868

>>> python scripts/apps/pp_worsen_inputs.py --critique_exec_file outputs/apps/critique/apps.train.gpt4o.tem1.0.n11.neg.intro.inter.gpt4o.tem1.0.s42.n1.json_obj.exec.json \
    --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.?-of-8.v1.1.json" \
    --completion_response_field response \
    --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.worsen_4o.s42.f10k.jsonl \
    --sample_num 10000 --seed 42 --completion_problem_id_field id --prompt_file prompts/apps/worsen_0shot_v1.0.txt 
    
     python azure/gpt_crawler_mp.py --prompt_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.worsen_4o.s42.f10k.jsonl \
     --outfile ../gpt-chat-examples/outputs/apps/critique/r2c.sft.train.0shot.tem1.0.n10.v1.1.worsen_4o.s42.f10k.gpt4o.tem1.0.s42.n1.json_obj.jsonl \
     --model gpt-4o --max_gen_tokens 4096 --temperature 1.0 --num_processes 24 --seed 42 --n 1 --response_format json_object

"""

